{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-20T10:08:49.063045Z",
     "start_time": "2020-10-20T10:08:48.091607Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-20T10:28:30.464930Z",
     "start_time": "2020-10-20T10:28:30.457197Z"
    }
   },
   "outputs": [],
   "source": [
    "# Seed PyTorch's global RNG so the randomly initialised RNNCell weights and\n",
    "# the random dataset below are identical on every Restart & Run All.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Sizes of the toy problem.\n",
    "batch_size = 1\n",
    "seq_len = 3\n",
    "input_size = 4\n",
    "hidden_size = 2\n",
    "\n",
    "# One recurrent step: maps (input, previous hidden) to the next hidden state.\n",
    "cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)\n",
    "\n",
    "# Random input sequence laid out as (seq_len, batch_size, input_size).\n",
    "dataset = torch.randn(seq_len, batch_size, input_size)\n",
    "# Initial hidden state: all zeros, shape (batch_size, hidden_size).\n",
    "hidden = torch.zeros(batch_size, hidden_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-20T10:33:32.931989Z",
     "start_time": "2020-10-20T10:33:32.881845Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================== 0 ====================\n",
      "Input size: torch.Size([1, 4])\n",
      "output size: torch.Size([1, 2])\n",
      "tensor([[ 0.2333, -0.2216]], grad_fn=<TanhBackward>)\n",
      "==================== 1 ====================\n",
      "Input size: torch.Size([1, 4])\n",
      "output size: torch.Size([1, 2])\n",
      "tensor([[-0.9512,  0.7459]], grad_fn=<TanhBackward>)\n",
      "==================== 2 ====================\n",
      "Input size: torch.Size([1, 4])\n",
      "output size: torch.Size([1, 2])\n",
      "tensor([[ 0.7577, -0.6482]], grad_fn=<TanhBackward>)\n"
     ]
    }
   ],
   "source": [
    "# Start from a zero hidden state every time this cell runs, so re-running it\n",
    "# does not continue from the state a previous run left behind.\n",
    "hidden = torch.zeros(batch_size, hidden_size)\n",
    "\n",
    "# Feed the sequence one time step at a time. The loop variable is `input_`\n",
    "# (not `input`) so the built-in input() is not shadowed in the notebook.\n",
    "for idx, input_ in enumerate(dataset):\n",
    "    print('=' *20 , idx,'='*20)\n",
    "    print('Input size:', input_.shape)\n",
    "\n",
    "    # The new hidden state is also the output of an RNNCell step.\n",
    "    hidden = cell(input_, hidden)\n",
    "    print('output size:', hidden.shape)\n",
    "    print(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-20T10:56:20.349485Z",
     "start_time": "2020-10-20T10:56:20.341447Z"
    }
   },
   "outputs": [],
   "source": [
    "# Character vocabulary: a character's class index is its position in this list.\n",
    "idx2char = ['e', 'h', 'l', 'o']\n",
    "# Input and target sequences, both given as indices into idx2char.\n",
    "x_data = [1, 0, 2, 2, 3]  # h e l l o\n",
    "y_data = [3, 1, 2, 3, 2]  # o h l o l\n",
    "\n",
    "# Row i of the lookup table is the one-hot vector for class i.\n",
    "one_hot_lookup = [[int(row == col) for col in range(4)] for row in range(4)]\n",
    "x_one_hot = list(map(one_hot_lookup.__getitem__, x_data))\n",
    "\n",
    "# inputs: (seq_len, batch_size, input_size) floats; labels: (seq_len, 1) class indices.\n",
    "inputs = torch.Tensor(x_one_hot).reshape(-1, batch_size, input_size)\n",
    "labels = torch.LongTensor(y_data).reshape(-1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-20T11:04:51.815965Z",
     "start_time": "2020-10-20T11:04:51.805516Z"
    }
   },
   "outputs": [],
   "source": [
    "class Model(torch.nn.Module):\n",
    "    \"\"\"Single-step recurrent model wrapping one torch.nn.RNNCell.\n",
    "\n",
    "    The training loop feeds the hidden state returned by forward() straight\n",
    "    into CrossEntropyLoss as class scores, so hidden_size must equal the\n",
    "    number of target classes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, batch_size):\n",
    "        \"\"\"Store the sizes and build the RNNCell.\n",
    "\n",
    "        input_size: width of each input vector.\n",
    "        hidden_size: width of the hidden state.\n",
    "        batch_size: batch dimension used by init_hidden().\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.batch_size = batch_size\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.rnncell = torch.nn.RNNCell(input_size=self.input_size,\n",
    "                                        hidden_size=self.hidden_size)\n",
    "\n",
    "    def forward(self, input, hidden):\n",
    "        \"\"\"Advance one time step and return the new hidden state.\n",
    "\n",
    "        input: tensor of shape (batch_size, input_size).\n",
    "        hidden: previous hidden state, shape (batch_size, hidden_size).\n",
    "        \"\"\"\n",
    "        return self.rnncell(input, hidden)\n",
    "\n",
    "    def init_hidden(self):\n",
    "        \"\"\"Return an all-zero hidden state of shape (batch_size, hidden_size).\"\"\"\n",
    "        return torch.zeros(self.batch_size, self.hidden_size)\n",
    "\n",
    "\n",
    "# The hidden state doubles as the class scores passed to CrossEntropyLoss, so\n",
    "# its width must be the vocabulary size (4). Reusing the global hidden_size=2\n",
    "# from the RNNCell demo makes the loss fail with 'Target 3 is out of bounds.'\n",
    "net = Model(input_size, len(idx2char), batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-20T11:17:21.984157Z",
     "start_time": "2020-10-20T11:17:21.951711Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted string= "
     ]
    },
    {
     "ename": "IndexError",
     "evalue": "Target 3 is out of bounds.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                             Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-16dfde746d2a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      9\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0minput_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlabel\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m         \u001b[0mhidden\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m+=\u001b[0m\u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlabel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m         \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx2char\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m' '\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/mytorch/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    720\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    721\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 722\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    723\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    724\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/mytorch/lib/python3.8/site-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m    945\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    946\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 947\u001b[0;31m         return F.cross_entropy(input, target, weight=self.weight,\n\u001b[0m\u001b[1;32m    948\u001b[0m                                ignore_index=self.ignore_index, reduction=self.reduction)\n\u001b[1;32m    949\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/mytorch/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mcross_entropy\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m   2420\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0msize_average\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2421\u001b[0m         \u001b[0mreduction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegacy_get_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2422\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mnll_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2423\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2424\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/mytorch/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mnll_loss\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m   2216\u001b[0m                          .format(input.size(0), target.size(0)))\n\u001b[1;32m   2217\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mdim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2218\u001b[0;31m         \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnll_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_enum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2219\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0mdim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2220\u001b[0m         \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnll_loss2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_enum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mignore_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mIndexError\u001b[0m: Target 3 is out of bounds."
     ]
    }
   ],
   "source": [
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=0.1)\n",
    "\n",
    "# Single source of truth for the epoch count (loop bound and progress message).\n",
    "NUM_EPOCHS = 15\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    loss = 0\n",
    "    optimizer.zero_grad()\n",
    "    # Every pass over the sequence starts from a fresh zero hidden state.\n",
    "    hidden = net.init_hidden()\n",
    "    print('Predicted string=', end=' ')\n",
    "    for input_, label in zip(inputs, labels):\n",
    "        hidden = net(input_, hidden)\n",
    "        # The hidden state is used directly as the class scores for this step.\n",
    "        loss += criterion(hidden, label)\n",
    "        # Index of the highest score = predicted character for this step.\n",
    "        _, idx = hidden.max(dim=1)\n",
    "        print(idx2char[idx.item()], end=' ')\n",
    "    # Back-propagate the loss summed over all time steps, then update once.\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    print(',Epoch [%d/%d]  loss=%.4f' % (epoch + 1, NUM_EPOCHS, loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
